#!/usr/bin/env python3
"""
Extract the inner-most spans and their bounding boxes, together with the MathML output,
from LaTeX equations rendered with Playwright and KaTeX.
Results are cached as JSON files named after a SHA1 hash of the equation and its rendering parameters.

Requirements:
    pip install playwright
    python -m playwright install chromium

    Place katex.min.css and katex.min.js in the same directory as this script.
"""

import os
import hashlib
import pathlib
import json
import re
import shutil
from dataclasses import dataclass
from typing import List
import unittest
import xml.etree.ElementTree as ET

from playwright.sync_api import sync_playwright, Error as PlaywrightError

@dataclass()
class BoundingBox:
    """
    Position and size of a rectangle on the rendered page.

    render_equation fills these values from the browser's getBoundingClientRect().
    """

    x: float
    y: float
    width: float
    height: float

@dataclass()
class SpanInfo:
    """Text content of one inner-most rendered span, paired with its bounding box."""

    text: str
    bounding_box: BoundingBox

@dataclass()
class RenderedEquation:
    """Result of rendering one equation: the MathML markup and the inner-most spans."""

    mathml: str
    spans: List[SpanInfo]

def get_equation_hash(equation, bg_color="white", text_color="black", font_size=24):
    """
    Return the hex SHA1 digest used as the cache key for one equation render.

    The equation source and every rendering parameter are formatted, joined with
    "|", UTF-8 encoded and hashed, so changing any of them changes the key.
    """
    parts = (equation, bg_color, text_color, font_size)
    key = "|".join(f"{part}" for part in parts)
    digest = hashlib.sha1(key.encode('utf-8'))
    return digest.hexdigest()

def get_cache_dir():
    """
    Return ~/.cache/olmocr/bench/equations as a Path.

    The directory (and any missing parents) is created first, so callers can
    write into it immediately.
    """
    target = pathlib.Path.home().joinpath('.cache', 'olmocr', 'bench', 'equations')
    target.mkdir(parents=True, exist_ok=True)
    return target


def clear_cache_dir():
    """
    Remove every file and subdirectory from the equation cache directory.

    The directory itself is deleted and then recreated, so it exists (empty)
    when this returns.
    """
    target = get_cache_dir()
    if not target.is_dir():
        return
    shutil.rmtree(target)
    # rmtree removed the directory itself; put an empty one back.
    target.mkdir(parents=True, exist_ok=True)


def _spans_from_json(raw_spans):
    """Build SpanInfo objects from dicts shaped {"text": ..., "boundingBox": {"x", "y", "width", "height"}}."""
    return [
        SpanInfo(
            text=record["text"],
            bounding_box=BoundingBox(
                x=record["boundingBox"]["x"],
                y=record["boundingBox"]["y"],
                width=record["boundingBox"]["width"],
                height=record["boundingBox"]["height"],
            ),
        )
        for record in raw_spans
    ]


def _spans_to_json(spans):
    """Inverse of _spans_from_json: serialize SpanInfo objects into JSON-ready dicts."""
    return [
        {
            "text": span.text,
            "boundingBox": {
                "x": span.bounding_box.x,
                "y": span.bounding_box.y,
                "width": span.bounding_box.width,
                "height": span.bounding_box.height,
            },
        }
        for span in spans
    ]


def _load_cached_equation(cache_file):
    """Return the RenderedEquation stored in cache_file, or None if it is missing or unreadable."""
    try:
        with open(cache_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return RenderedEquation(mathml=data["mathml"], spans=_spans_from_json(data["spans"]))
    except FileNotFoundError:
        return None
    except (ValueError, KeyError, TypeError) as ex:
        # A corrupt/truncated entry should trigger a re-render, not break every future call.
        print(f"Ignoring unreadable cache file {cache_file}: {ex}")
        return None


def _write_cache_atomically(cache_file, payload):
    """Write payload as JSON to a temp file next to cache_file, then os.replace it into place."""
    # Readers therefore see either the old entry or the complete new one, never a partial file.
    tmp_file = cache_file.with_name(f"{cache_file.name}.{os.getpid()}.tmp")
    with open(tmp_file, 'w', encoding='utf-8') as f:
        json.dump(payload, f)
    os.replace(tmp_file, cache_file)


def render_equation(
    equation,
    bg_color="white",
    text_color="black",
    font_size=24,
    use_cache=True,
    debug_dom=False,
):
    """
    Render a LaTeX equation using Playwright and KaTeX, extract the inner-most span elements
    (those without child elements that contain non-whitespace text) along with their bounding boxes,
    and also extract the MathML output generated by KaTeX.

    Args:
        equation: LaTeX source, rendered by KaTeX with displayMode and throwOnError enabled.
        bg_color: CSS background color of the page.
        text_color: CSS text color of the page.
        font_size: Font size in px of the equation container.
        use_cache: When True, return a previously cached result (or None for a cached
            error marker) without launching a browser. Fresh results are written to the
            cache regardless of this flag.
        debug_dom: When True, print the KaTeX DOM HTML and the extracted span data.

    Returns:
        RenderedEquation: the MathML string ("" if no `.katex-mathml math` element exists)
        and a list of SpanInfo with whitespace-trimmed text. Returns None if KaTeX reports
        an error for the equation; an "<hash>_error" marker file is then written to the
        cache directory.

    Raises:
        FileNotFoundError: katex.min.css or katex.min.js is missing from this script's directory.
        RuntimeError: no global `katex` is defined after loading katex.min.js.
        playwright.sync_api.Error: the in-page render call itself failed.
    """
    # Calculate hash for caching
    eq_hash = get_equation_hash(equation, bg_color, text_color, font_size)
    cache_dir = get_cache_dir()
    cache_file = cache_dir / f"{eq_hash}.json"
    cache_error_file = cache_dir / f"{eq_hash}_error"

    if use_cache:
        if cache_error_file.exists():
            return None
        cached = _load_cached_equation(cache_file)
        if cached is not None:
            return cached

    # Get local paths for KaTeX files
    script_dir = os.path.dirname(os.path.abspath(__file__))
    katex_css_path = os.path.join(script_dir, "katex.min.css")
    katex_js_path = os.path.join(script_dir, "katex.min.js")

    if not os.path.exists(katex_css_path) or not os.path.exists(katex_js_path):
        raise FileNotFoundError(f"KaTeX files not found. Please ensure katex.min.css and katex.min.js are in {script_dir}")

    # Basic HTML structure
    html = f"""
    <!DOCTYPE html>
    <html>
    <head>
        <style>
            body {{
                display: flex;
                justify-content: center;
                align-items: center;
                height: 100vh;
                margin: 0;
                background-color: {bg_color};
                color: {text_color};
            }}
            #equation-container {{
                padding: 0;
                font-size: {font_size}px;
            }}
        </style>
    </head>
    <body>
        <div id="equation-container"></div>
    </body>
    </html>
    """

    with sync_playwright() as p:
        browser = p.chromium.launch()
        # try/finally guarantees the browser is closed on every exit path,
        # including the RuntimeError, the re-raised PlaywrightError and the error return.
        try:
            page = browser.new_page(viewport={"width": 800, "height": 400})
            page.set_content(html)
            page.add_style_tag(path=katex_css_path)
            page.add_script_tag(path=katex_js_path)
            page.wait_for_load_state("networkidle")

            if not page.evaluate("typeof katex !== 'undefined'"):
                raise RuntimeError("KaTeX library failed to load. Check your katex.min.js file.")

            try:
                # The LaTeX is passed as an evaluate() argument rather than spliced into
                # the JS source, so it needs no quoting or escaping.
                error_message = page.evaluate(
                    """
                    (latex) => {
                        try {
                            katex.render(latex, document.getElementById("equation-container"), {
                                displayMode: true,
                                throwOnError: true
                            });
                            return null;
                        } catch (error) {
                            console.error("KaTeX error:", error.message);
                            return error.message;
                        }
                    }
                    """,
                    equation,
                )
            except PlaywrightError:
                print(f"Playwright error while rendering equation: {equation!r}")
                raise

            if error_message:
                print(f"Error rendering equation: '{equation}'")
                print(error_message)
                cache_error_file.touch()
                return None

            page.wait_for_selector(".katex", state="attached")

            if debug_dom:
                katex_dom_html = page.evaluate(
                    '() => document.getElementById("equation-container").innerHTML'
                )
                print("\n===== KaTeX DOM HTML =====")
                print(katex_dom_html)

            # Extract inner-most spans with non-whitespace text. Raw string so the JS
            # regex /\S/ is not an (invalid) Python escape sequence.
            spans_info = page.evaluate(r"""
            () => {
                const spans = Array.from(document.querySelectorAll('span'));
                const list = [];
                spans.forEach(span => {
                    // Check if this span has no child elements and contains non-whitespace text
                    if (span.children.length === 0 && /\S/.test(span.textContent)) {
                        const rect = span.getBoundingClientRect();
                        list.push({
                            text: span.textContent.trim(),
                            boundingBox: {
                                x: rect.x,
                                y: rect.y,
                                width: rect.width,
                                height: rect.height
                            }
                        });
                    }
                });
                return list;
            }
            """)

            if debug_dom:
                print("\n===== Extracted Span Information =====")
                print(spans_info)

            # The <math> element inside ".katex-mathml", or "" if there is none.
            mathml = page.evaluate("""
            () => {
                const mathElem = document.querySelector('.katex-mathml math');
                return mathElem ? mathElem.outerHTML : "";
            }
            """)
        finally:
            browser.close()

    rendered_eq = RenderedEquation(mathml=mathml, spans=_spans_from_json(spans_info))
    _write_cache_atomically(
        cache_file,
        {"mathml": rendered_eq.mathml, "spans": _spans_to_json(rendered_eq.spans)},
    )
    return rendered_eq


def compare_rendered_equations(haystack: "RenderedEquation", needle: "RenderedEquation") -> bool:
    """
    Return True if the cleaned MathML of one rendered equation contains the other's.

    Both MathML strings are cleaned before comparison:
    - namespaces are removed;
    - if a <semantics> element exists, only its non-<annotation> children are kept;
    - all whitespace is stripped.

    The shorter cleaned string is then tested as a substring of the longer one, so
    the argument order does not affect the result.

    Returns False if either argument is None (render_equation's result for a failed
    render). It also returns False if the shorter cleaned MathML is empty, because an
    empty string would otherwise trivially "match" every equation.
    """
    if haystack is None or needle is None:
        return False

    def strip_namespaces(elem: ET.Element) -> ET.Element:
        """
        Remove namespace prefixes ("{uri}tag" -> "tag") from elem and all of its
        descendants, in place, and return elem.
        """
        for sub in elem.iter():
            if '}' in sub.tag:
                sub.tag = sub.tag.split('}', 1)[1]
        return elem

    def extract_inner(mathml: str) -> str:
        """
        Parse the MathML and remove namespaces. If a <semantics> element exists,
        concatenate the serializations of its children except <annotation>;
        otherwise serialize the whole cleaned tree. Unparseable input is returned
        unchanged, and empty input yields "".
        """
        if not mathml:
            return ""
        try:
            root = strip_namespaces(ET.fromstring(mathml))
        except ET.ParseError as e:
            # Best effort: fall back to comparing the raw markup.
            print("Error parsing MathML:", e)
            return mathml
        semantics = root.find('semantics')
        if semantics is None:
            return ET.tostring(root, encoding='unicode')
        return ''.join(
            ET.tostring(child, encoding='unicode')
            for child in semantics
            if child.tag != 'annotation'
        )

    def normalize(s: str) -> str:
        """
        Remove all whitespace from the string.
        """
        return re.sub(r'\s+', '', s)

    haystack_inner = normalize(extract_inner(haystack.mathml))
    needle_inner = normalize(extract_inner(needle.mathml))

    # Make the comparison symmetric: always look for the shorter string in the longer.
    if len(needle_inner) > len(haystack_inner):
        needle_inner, haystack_inner = haystack_inner, needle_inner

    if not needle_inner:
        return False
    return needle_inner in haystack_inner

class TestRenderedEquationComparison(unittest.TestCase):
    """
    End-to-end tests for compare_rendered_equations using real renders from
    render_equation (needs Playwright/Chromium and the local KaTeX files).
    """

    def _render(self, latex):
        """
        Render latex without reading the cache and fail immediately if
        render_equation returned None. A rendering failure is then reported as
        such, not as a misleading comparison result (or a spurious assertFalse pass).
        """
        rendered = render_equation(latex, use_cache=False)
        self.assertIsNotNone(rendered, f"render_equation failed for {latex!r}")
        return rendered

    def test_exact_match(self):
        # Both calls with identical LaTeX should produce matching MathML output.
        eq1 = self._render("a+b")
        eq2 = self._render("a+b")
        self.assertTrue(compare_rendered_equations(eq1, eq2))

    def test_whitespace_difference(self):
        # Differences in whitespace in the LaTeX input should not affect the MathML output.
        eq1 = self._render("a+b")
        eq2 = self._render("a + b")
        self.assertTrue(compare_rendered_equations(eq1, eq2))

    def test_not_found(self):
        # Completely different equations should not match.
        eq1 = self._render("c-d")
        eq2 = self._render("a+b")
        self.assertFalse(compare_rendered_equations(eq1, eq2))

    def test_align_block_contains_needle(self):
        # The MathML output of the plain equation should be found within the align block output.
        eq_plain = self._render("a+b")
        eq_align = self._render("\\begin{align*}a+b\\end{align*}")
        self.assertTrue(compare_rendered_equations(eq_align, eq_plain))

    def test_align_block_needle_not_in(self):
        # An align block rendering a different equation should not contain the MathML of an unrelated equation.
        eq_align = self._render("\\begin{align*}a+b\\end{align*}")
        eq_diff = self._render("c-d")
        self.assertFalse(compare_rendered_equations(eq_align, eq_diff))

    def test_big(self):
        # A larger equation should still be found inside its align* wrapper.
        ref_rendered = self._render("\\nabla \\cdot \\mathbf{E} = \\frac{\\rho}{\\varepsilon_0}")
        align_rendered = self._render("\\begin{align*}\\nabla \\cdot \\mathbf{E} = \\frac{\\rho}{\\varepsilon_0}\\end{align*}")
        self.assertTrue(compare_rendered_equations(ref_rendered, align_rendered))

if __name__ == "__main__":
    # Run this module's unittest suite when the file is executed directly.
    unittest.main()
